{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "from __init__ import list_rois, extract_samples, make_table, load_fit, get_waic_and_loo, get_waic, get_fit_path, get_model_path\n",
    "warnings.simplefilter(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "fits_path = './fits'\n",
    "models_path = '.'\n",
    "#model_names = ['reducedlinearmodelq0', 'reducedlinearmodelq0ctime', 'reducedlinearmodelNegBinom']\n",
    "model_name = 'reducedlinearmodelq0'\n",
    "fit_format = 1\n",
    "\n",
    "params = ['R0', 'car', 'ifr']\n",
    "quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 5 ROIs\n"
     ]
    }
   ],
   "source": [
    "model_path = get_model_path(models_path, model_name)\n",
    "extension = ['csv', 'pkl'][fit_format]\n",
    "rois = list_rois(fits_path, model_name, extension)\n",
    "print(\"There are %d ROIs\" % len(rois))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5/5 [00:34<00:00,  5.30s/it]\n"
     ]
    }
   ],
   "source": [
    "dfs = []\n",
    "\n",
    "for roi in tqdm(rois):\n",
    "    fit_path = get_fit_path(fits_path, model_name, roi)\n",
    "    if fit_format==1:\n",
    "        fit = load_fit(fit_path, model_path)\n",
    "        stats = get_waic_and_loo(fit)\n",
    "        #print(stats)\n",
    "        samples = fit.to_dataframe()\n",
    "    elif fit_format==0:\n",
    "        samples = extract_samples(fits_path, models_path, model_name, roi, fit_format)\n",
    "        stats = get_waic(samples)\n",
    "    df = make_table(roi, samples, params, stats, quantiles=quantiles)\n",
    "    dfs.append(df)\n",
    "\n",
    "df = pd.concat(dfs)\n",
    "df.to_csv('%s_fit_table.csv' % model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The ROIs you want to plot\n",
    "roi_subset = ['Germany', 'US_MI', 'Spain']\n",
    "\n",
    "def plot_table_data(df, roi_subset):\n",
    "    n_params = df.shape[1]\n",
    "    fig, axes = plt.subplots(1, n_params, figsize=(n_params*5, 5))\n",
    "    for i, ax in enumerate(axes.flat):\n",
    "        for j, roi in enumerate(roi_subset):\n",
    "            col = df.columns[i]\n",
    "            boxes = [\n",
    "                {\n",
    "                'x': i,\n",
    "                'label' : roi,\n",
    "                'whislo': df.loc[(roi, np.min(quantiles)), col],    # Bottom whisker position\n",
    "                'q1'    : df.loc[(roi, 0.25), col],    # First quartile (25th percentile)\n",
    "                'med'   : df.loc[(roi, 0.5), col],    # Median         (50th percentile)\n",
    "                'q3'    : df.loc[(roi, 0.75), col],    # Third quartile (75th percentile)\n",
    "                'whishi': df.loc[(roi, np.max(quantiles)), col],    # Top whisker position\n",
    "                'fliers': []        # Outliers\n",
    "                }\n",
    "            ]\n",
    "            ax.bxp(boxes, positions=[j], showfliers=False)\n",
    "        ax.set_title(df.columns[i])\n",
    "        \n",
    "plot_table_data(df, roi_subset)\n",
    "plt.savefig('%s_plot_table.png' % model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
